{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Components\n",
    "\n",
    "The 5 main components of a `WideDeep` model are:\n",
    "\n",
    "1. `Wide`\n",
    "2. `DeepDense`\n",
    "3. `DeepText`\n",
    "4. `DeepImage`\n",
    "5. `deephead`\n",
    "\n",
    "The first 4 are collected and combined by the `WideDeep` collector class, while the 5th can optionally be added to the `WideDeep` model through its corresponding parameters: `deephead`, or alternatively `head_layers`, `head_dropout` and `head_batchnorm`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Wide\n",
    "\n",
    "The wide component is a Linear layer \"plugged\" into the output neuron(s).\n",
    "\n",
    "The only particularity of our implementation is that we have implemented the linear layer via an Embedding layer plus a bias. While the two implementations are equivalent, the latter is faster and far more memory efficient, since we do not need to one-hot encode the categorical features.\n",
    "\n",
    "Let's assume we have the following dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>color</th>\n",
       "      <th>size</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>r</td>\n",
       "      <td>s</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>b</td>\n",
       "      <td>n</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>g</td>\n",
       "      <td>l</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  color size\n",
       "0     r    s\n",
       "1     b    n\n",
       "2     g    l"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame({'color': ['r', 'b', 'g'], 'size': ['s', 'n', 'l']})\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One-hot encoded, the first observation would be:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_0_oh = (np.array([1., 0., 0., 1., 0., 0.])).astype('float32')"
   ]
  },
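  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we instead numerically encode (label encode, or `le`) the values, each value is replaced by its integer index. Here the indices start from 0 so that they line up with the one-hot positions above; in the package itself, encoding starts from 1, and 0 is saved for padding, i.e. unseen values."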
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_0_le = (np.array([0, 3])).astype('int64')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's see if the two implementations are equivalent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we have 6 different values. Let's assume we are performing a regression, so pred_dim = 1\n",
    "lin = nn.Linear(6, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb = nn.Embedding(6, 1) \n",
    "emb.weight = nn.Parameter(lin.weight.reshape_as(emb.weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.9452], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin(torch.tensor(obs_0_oh))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.9452], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb(torch.tensor(obs_0_le)).sum() + lin.bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And this is precisely how the linear component `Wide` is implemented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import Wide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "?Wide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Wide(\n",
       "  (wide_linear): Embedding(11, 1, padding_idx=0)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wide = Wide(wide_dim=10, pred_dim=1)\n",
    "wide"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that even though the input dim is 10, the Embedding layer has 11 weights. This is because index 0 is saved for padding, which is used for unseen values during the encoding process."
   ]
  },
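  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of the padding index concrete, here is a minimal sketch that uses a plain `nn.Embedding` with `padding_idx=0`. It is not the `Wide` source code, and the names `emb_pad`, `bias`, `seen` and `partly_unseen` are made up for illustration. The row at index 0 is initialised to zeros, so an unseen value contributes nothing to the linear sum:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# 10 encoded values plus index 0 for padding (illustrative, not the Wide class itself)\n",
    "emb_pad = nn.Embedding(11, 1, padding_idx=0)\n",
    "bias = nn.Parameter(torch.zeros(1))\n",
    "\n",
    "# one observation with two seen values, and one where the second value is unseen\n",
    "seen = torch.tensor([[3, 7]])\n",
    "partly_unseen = torch.tensor([[3, 0]])\n",
    "\n",
    "# sum the per-value weights along the feature dimension and add the bias\n",
    "out_seen = emb_pad(seen).sum(dim=1) + bias\n",
    "out_partly_unseen = emb_pad(partly_unseen).sum(dim=1) + bias\n",
    "\n",
    "# the unseen value adds nothing: only the weight of index 3 remains\n",
    "print(out_seen, out_partly_unseen, emb_pad(torch.tensor([3])))"
   ]
  },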
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. DeepDense\n",
    "\n",
    "The `DeepDense` component consists of a stack of dense layers that receive the embedding representation of the categorical features concatenated with the numerical continuous features. For those familiar with Fastai's tabular API, `DeepDense` is almost identical to their tabular model, although `DeepDense` allows for more flexibility when defining the embedding dimensions. Let's have a look at `DeepDense`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import DeepDense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fake dataset: 4 categorical columns (a-d, 4 levels each) plus 1 continuous column (e)\n",
    "X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), dim=1)\n",
    "colnames = ['a', 'b', 'c', 'd', 'e']\n",
    "embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]\n",
    "deep_column_idx = {k:v for v,k in enumerate(colnames)}\n",
    "continuous_cols = ['e']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "?DeepDense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "deepdense = DeepDense(hidden_layers=[16,8], dropout=[0.5, 0.5], batchnorm=True, deep_column_idx=deep_column_idx,\n",
    "                      embed_input=embed_input, continuous_cols=continuous_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeepDense(\n",
       "  (embed_layers): ModuleDict(\n",
       "    (emb_layer_a): Embedding(4, 8)\n",
       "    (emb_layer_b): Embedding(4, 8)\n",
       "    (emb_layer_c): Embedding(4, 8)\n",
       "    (emb_layer_d): Embedding(4, 8)\n",
       "  )\n",
       "  (embed_dropout): Dropout(p=0.0, inplace=False)\n",
       "  (dense): Sequential(\n",
       "    (dense_layer_0): Sequential(\n",
       "      (0): Linear(in_features=33, out_features=16, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "    )\n",
       "    (dense_layer_1): Sequential(\n",
       "      (0): Linear(in_features=16, out_features=8, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deepdense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0000, -0.9949,  3.8273,  0.0000, -1.3889, -2.9641,  0.0000, -0.0000],\n",
       "        [ 3.9123, -0.0000, -0.0000,  1.9555, -1.3561,  1.7069, -0.0000,  0.9275],\n",
       "        [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -1.6489, -0.0000, -1.4985],\n",
       "        [-1.2736,  0.0000, -1.2819,  2.1232,  0.0000,  2.2767, -0.0000,  3.5354],\n",
       "        [-0.1726, -0.0000, -1.3275, -0.0000, -1.3703,  0.0000, -0.0000, -1.4637]],\n",
       "       grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deepdense(X_deep)"
   ]
  },
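  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `in_features=33` of the first dense layer printed above is consistent with concatenating the four 8-dimensional embeddings with the single continuous column. The cell below is only a quick check of that arithmetic in plain Python, not the library's internal code; `emb_out_dim` and `dense_input_dim` are names introduced here for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same toy setup as above: (column name, number of levels, embedding dim)\n",
    "colnames = ['a', 'b', 'c', 'd', 'e']\n",
    "embed_input = [(col, 4, 8) for col in colnames[:4]]\n",
    "continuous_cols = ['e']\n",
    "\n",
    "# total width of the concatenated embeddings: 4 columns x 8 dims\n",
    "emb_out_dim = sum(dim for _, _, dim in embed_input)\n",
    "\n",
    "# each continuous column adds one more input feature\n",
    "dense_input_dim = emb_out_dim + len(continuous_cols)\n",
    "print(dense_input_dim)"
   ]
  },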
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. DeepText\n",
    "\n",
    "The `DeepText` class within the `pytorch_widedeep` package is a standard and simple stack of LSTMs on top of word embeddings. You could also add a FC-Head on top of the LSTMs. The word embeddings can be pre-trained. In the future I aim to include some simple pretrained models so that the combination between text and images is fair.\n",
    "\n",
    "*While I recommend using the `Wide` and `DeepDense` classes within this package when building the corresponding model components, it is very likely that the user will want to use custom text and image models. That is perfectly possible: simply build them and pass them as the corresponding parameters. Note that the custom models MUST return a last layer of activations (i.e. not the final prediction) so that these activations are collected by `WideDeep` and combined accordingly. In addition, the models MUST also contain an attribute `output_dim` with the size of this last layer of activations.*\n",
    "\n",
    "Let's have a look at the `DeepText` class:"
   ]
  },
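  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the two requirements for custom models stated above, the sketch below defines a hypothetical `MyTextModel` (a name invented here, not part of the package). It uses a GRU instead of an LSTM simply to show that the architecture is up to you. The only assumptions taken from the text are that `forward` returns the last layer of activations and that the module exposes an `output_dim` attribute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class MyTextModel(nn.Module):\n",
    "    # hypothetical custom text component, for illustration only\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)\n",
    "        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)\n",
    "        # size of the last layer of activations, as required by the text above\n",
    "        self.output_dim = hidden_dim\n",
    "\n",
    "    def forward(self, X):\n",
    "        embedded = self.embed(X.long())\n",
    "        _, h = self.rnn(embedded)\n",
    "        # return activations (last hidden state), not a final prediction\n",
    "        return h[-1]\n",
    "\n",
    "\n",
    "my_text = MyTextModel(vocab_size=4, embed_dim=4, hidden_dim=6)\n",
    "\n",
    "# fake padded sequences: a leading 0 followed by 4 token ids in [1, 3]\n",
    "X_custom = torch.cat((torch.zeros([5, 1]), torch.empty(5, 4).random_(1, 4)), dim=1)\n",
    "out_custom = my_text(X_custom)\n",
    "print(out_custom.shape, my_text.output_dim)"
   ]
  },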
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pytorch_widedeep.models import DeepText"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "?DeepText"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fake padded token sequences: a leading 0 (padding index) followed by 4 token ids in [1, 3]\n",
    "X_text = torch.cat((torch.zeros([5,1]), torch.empty(5, 4).random_(1,4)), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "deeptext = DeepText(vocab_size=4, hidden_dim=4, n_layers=1, padding_idx=0, embed_dim=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeepText(\n",
       "  (word_embed): Embedding(4, 4, padding_idx=0)\n",
       "  (rnn): LSTM(4, 4, batch_first=True)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deeptext"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You could, if you wanted, add a Fully Connected Head (FC-Head) on top of it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "deeptext = DeepText(vocab_size=4, hidden_dim=8, n_layers=1, padding_idx=0, embed_dim=4, head_layers=[8,4], \n",
    "                    head_batchnorm=True, head_dropout=[0.5, 0.5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeepText(\n",
       "  (word_embed): Embedding(4, 4, padding_idx=0)\n",
       "  (rnn): LSTM(4, 8, batch_first=True)\n",
       "  (texthead): Sequential(\n",
       "    (dense_layer_0): Sequential(\n",
       "      (0): Linear(in_features=8, out_features=4, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deeptext"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that since the FC-Head will receive the activations from the last hidden layer of the stack of RNNs, the corresponding dimensions must be consistent. In the example above, the first element of `head_layers` (8) matches `hidden_dim`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. DeepImage\n",
    "\n",
    "The `DeepImage` class within the `pytorch_widedeep` package is either a pretrained ResNet (18, 34 or 50; the default is 18) or a stack of CNNs, to which one can add a FC-Head. If it is a pretrained ResNet, you can choose how many layers deep into the network you want to unfreeze with the parameter `freeze`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import DeepImage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "?DeepImage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_img = torch.rand((2,3,224,224))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "deepimage = DeepImage(head_layers=[512, 64, 8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeepImage(\n",
       "  (backbone): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (4): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (6): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (7): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  )\n",
       "  (imagehead): Sequential(\n",
       "    (dense_layer_0): Sequential(\n",
       "      (0): Linear(in_features=512, out_features=64, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): Dropout(p=0.0, inplace=False)\n",
       "    )\n",
       "    (dense_layer_1): Sequential(\n",
       "      (0): Linear(in_features=64, out_features=8, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): Dropout(p=0.0, inplace=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deepimage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-2.2825e-03, -8.3100e-04, -8.8423e-04, -1.1084e-04,  8.8529e-02,\n",
       "         -5.1577e-04,  2.8343e-01, -1.7071e-03],\n",
       "        [-1.8486e-03, -8.5602e-04, -1.8552e-03,  3.6481e-01,  9.0812e-02,\n",
       "         -9.6603e-04,  3.9017e-01, -2.6355e-03]], grad_fn=<LeakyReluBackward1>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deepimage(X_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `pretrained=False`, a stack of 4 convolutional layers is used instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "deepimage = DeepImage(pretrained=False, head_layers=[512, 64, 8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeepImage(\n",
       "  (backbone): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n",
       "      (2): LeakyReLU(negative_slope=0.1, inplace=True)\n",
       "      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n",
       "      (2): LeakyReLU(negative_slope=0.1, inplace=True)\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n",
       "      (2): LeakyReLU(negative_slope=0.1, inplace=True)\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n",
       "      (2): LeakyReLU(negative_slope=0.1, inplace=True)\n",
       "      (adaptiveavgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (imagehead): Sequential(\n",
       "    (dense_layer_0): Sequential(\n",
       "      (0): Linear(in_features=512, out_features=64, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): Dropout(p=0.0, inplace=False)\n",
       "    )\n",
       "    (dense_layer_1): Sequential(\n",
       "      (0): Linear(in_features=64, out_features=8, bias=True)\n",
       "      (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (2): Dropout(p=0.0, inplace=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deepimage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. deephead\n",
    "\n",
    "Note that I do not use uppercase here. This is because, by default, the `deephead` is not defined outside `WideDeep`, unlike the rest of the components.\n",
    "\n",
    "When defining the `WideDeep` model there is a parameter called `head_layers` (and the corresponding `head_dropout` and `head_batchnorm`) that defines the FC-Head on top of `DeepDense`, `DeepText` and `DeepImage`.\n",
    "\n",
    "Of course, you could also choose to define it yourself externally and pass it using the parameter `deephead`. Have a look:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import WideDeep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "?WideDeep"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
